

import argparse
import builtins
import math
import os
import random
import time
import warnings
import pprint
import logging
import torch.multiprocessing as mp
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import random
import itertools
import scipy.io as scio
import sys
from datasets import *
import models
from tqdm import tqdm

from sklearn.model_selection import ParameterGrid

parser = argparse.ArgumentParser(description='PyTorch Semi-supervised Node Classification')

# Declarative table of (option strings, add_argument keyword options). The
# entries are registered in this order, which is also the order they are
# listed in --help.
_CLI_ARGUMENTS = (
    # dataset
    (('-d', '--dataset'), dict(metavar='DATASET', default='cora', type=str)),
    (('--data-path',), dict(metavar='DATAPATH', default='./data', type=str)),
    (('--num-labels',), dict(default=1, type=int, help='the number of labels per class')),
    # NOTE(review): this help text duplicates --num-labels verbatim; the value is
    # forwarded to load_data_1 as its third argument — confirm its real meaning.
    (('--num-val',), dict(default=20, type=int, help='the number of labels per class')),
    # network
    (('-a', '--arch'), dict(metavar='ARCH', default='gcn', type=str)),
    (('--lamb',), dict(default=1.0, type=float)),
    (('--scale',), dict(default=1.0, type=float)),
    (('--n-hidden',), dict(default=128, type=int)),
    # training settings
    (('--epochs',), dict(default=500, type=int, metavar='N',
                         help='number of total epochs to run')),
    (('--num-trials',), dict(default=100, type=int)),
    # optimization settings
    (('--lr', '--learning-rate'), dict(default=0.1, type=float, metavar='LR',
                                       help='initial (base) learning rate', dest='lr')),
    (('--wd', '--weight-decay'), dict(default=1e-2, type=float, metavar='W',
                                      help='weight decay (default: 1e-2)', dest='wd')),
    (('--dropout',), dict(default=0.8, type=float)),
    (('--T',), dict(default=10, type=float)),
    # others
    (('--seed',), dict(default=None, type=int, help='seed for initializing training. ')),
    (('--gpu',), dict(default=0, type=int, help='GPU id to use.')),
    (('--save',), dict(default='./results', type=str)),
)

for _flags, _options in _CLI_ARGUMENTS:
    parser.add_argument(*_flags, **_options)
del _flags, _options


# Module-level placeholders. main() assigns its own local max_val / max_test,
# and no `global` declaration for these names appears in this file section,
# so these values are not updated by training.
max_val = 0.0
max_test = 0.0


def get_logger(args):
    """Configure and return the root logger for this run.

    INFO-and-above messages go to ``<args.save>/log.txt`` (appended, with a
    timestamp prefix) and to stderr (message text only).

    The file handler is attached explicitly rather than via
    ``logging.basicConfig``, because ``basicConfig`` silently does nothing
    once the root logger already has any handler. Repeated calls do not
    stack duplicate handlers, so each record is emitted once per destination.

    Args:
        args: namespace whose ``save`` attribute names an existing directory.

    Returns:
        The root ``logging.Logger``.
    """
    head = '%(asctime)-15s %(message)s'
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    log_path = os.path.abspath(os.path.join(args.save, 'log.txt'))
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == log_path
        for h in logger.handlers
    )
    if not has_file_handler:
        file_handler = logging.FileHandler(log_path)  # default mode 'a' (append)
        file_handler.setFormatter(logging.Formatter(head))
        logger.addHandler(file_handler)

    # FileHandler subclasses StreamHandler, so compare exact types to detect a
    # plain console handler we attached on an earlier call.
    if not any(type(h) is logging.StreamHandler for h in logger.handlers):
        logger.addHandler(logging.StreamHandler())
    return logger

def _build_model(args, features, labels, idx_train, edge_index, num_classes):
    """Instantiate the network selected by ``args.arch`` (caller moves it to GPU).

    Raises:
        ValueError: if ``args.arch`` is not one of mlp / gcn / gat / gpn / vgpn.
    """
    arch = args.arch
    if arch in ('mlp', 'gcn', 'gat'):
        dense_kwargs = dict(n_features=features.shape[1],
                            n_hidden=args.n_hidden,
                            n_classes=num_classes,
                            dropout=args.dropout)
        if arch == 'mlp':
            return models.MLP(**dense_kwargs)
        if arch == 'gcn':
            return models.GCN(**dense_kwargs)
        return models.GAT(**dense_kwargs)
    if arch in ('gpn', 'vgpn'):
        # GPN / VGPN receive the graph, training indices and labels at
        # construction time (as GPU tensors).
        graph_kwargs = dict(features=features.cuda(),
                            n_hidden=args.n_hidden,
                            edge_indices_no_diag=edge_index.cuda(),
                            idx_train=idx_train.cuda(),
                            labels=labels.cuda(),
                            leaky_rate=0.1,
                            adj=[False, 1],
                            dropout=args.dropout,
                            T=args.T)
        if arch == 'gpn':
            return models.GPN(**graph_kwargs)
        return models.VGPN(lamb=args.lamb, **graph_kwargs)
    raise ValueError('Unknown architecture: %r' % (arch,))


def _forward(model, arch, features, adj, edge_index):
    """Run ``model`` with the inputs its architecture expects and return its output."""
    if arch in ('mlp', 'gpn', 'vgpn'):
        return model(features)
    if arch == 'gcn':
        return model(features, adj)
    if arch == 'gat':
        return model(features, edge_index)
    raise ValueError('Unknown architecture: %r' % (arch,))


def _accuracy(preds, labels, idx):
    """Fraction of positions selected by ``idx`` where ``preds`` equals ``labels``."""
    return torch.sum(preds[idx] == labels[idx]).float() / labels[idx].shape[0]


def main(args):
    """Train one model and return its test accuracy at the best validation epoch.

    Seeds numpy / torch / CUDA / ``random`` with ``args.seed`` when it is set
    (and turns on cuDNN determinism). It then selects GPU ``args.gpu``, loads
    the dataset via ``load_data_1`` and trains for ``args.epochs`` Adam steps.
    The training loss is the NLL of ``log_softmax(args.scale * logits)`` on the
    training nodes. After every step the model is evaluated, and the test
    accuracy is recorded whenever validation accuracy strictly improves.

    Requires CUDA (all tensors and the model are moved to the current GPU).

    Returns:
        Test accuracy in percent at the best-validation epoch, as a 0-d numpy
        value; 0 if validation accuracy never exceeds 0.

    Raises:
        ValueError: if ``args.arch`` is not a supported architecture.
    """
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        random.seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                    'This will turn on the CUDNN deterministic setting, '
                    'which can slow down your training considerably! '
                    'You may see unexpected behavior when restarting '
                    'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU.')
        torch.cuda.set_device(args.gpu)

    adj, features, labels, idx_train, idx_val, idx_test = load_data_1(args.dataset, args.num_labels, args.num_val)
    num_classes = int(labels.max()) + 1
    edge_index = adj.coalesce().indices()
    # Presumably removes diagonal (self-loop) entries — helper defined elsewhere.
    edge_index = DelDiagEdgeIndex(edge_index)

    model = _build_model(args, features, labels, idx_train, edge_index, num_classes)
    model.cuda()

    # Transfer inputs to the GPU once instead of on every forward pass.
    features_gpu = features.cuda()
    labels_gpu = labels.cuda()
    adj_gpu = adj.cuda() if args.arch == 'gcn' else None
    edge_index_gpu = edge_index.cuda() if args.arch == 'gat' else None
    train_targets = labels_gpu[idx_train].long()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)

    max_val = 0
    max_test = 0
    for _ in range(args.epochs):
        model.train()
        optimizer.zero_grad()
        output = _forward(model, args.arch, features_gpu, adj_gpu, edge_index_gpu)
        loss = F.nll_loss(F.log_softmax(args.scale * output[idx_train, :], dim=1), train_targets)
        loss.backward()
        optimizer.step()

        model.eval()
        # Evaluation needs no autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            output = _forward(model, args.arch, features_gpu, adj_gpu, edge_index_gpu)
            preds = torch.argmax(output, dim=1)
            acc_val = _accuracy(preds, labels_gpu, idx_val)
            acc_test = _accuracy(preds, labels_gpu, idx_test)
        if acc_val > max_val:
            max_val = acc_val
            max_test = acc_test.cpu().numpy()
    torch.cuda.empty_cache()
    return max_test * 100

def main_worker(args, gpu):
    """Run ``args.num_trials`` independent trials of ``main`` on one GPU.

    Trial ``k`` is seeded with ``k``. ``args`` is modified in place: ``gpu`` is
    stored in ``args.gpu`` and ``args.seed`` is overwritten before each trial.

    Returns:
        numpy array of the values returned by ``main``, in seed order.
    """
    args.gpu = gpu

    def _run_trial(seed):
        args.seed = seed
        return main(args)

    return np.array([_run_trial(seed) for seed in range(args.num_trials)])

def get_search_space(arch):
    """Return the hyper-parameter grid (dict of candidate lists) for ``arch``.

    ``x_num_labels`` values are copied into ``args.num_labels``; the other
    keys are copied into the same-named attribute of ``args``. ``vgpn``
    additionally searches over ``lamb``.

    Raises:
        ValueError: if no search space is defined for ``arch``.
    """
    if arch in ('mlp', 'gcn', 'gpn', 'gat'):
        return {
            'lr': [0.1, 0.05, 0.01],
            'wd': [1e-2, 5e-3, 1e-3, 5e-4],
            'scale': [1.0, 5.0],
            'x_num_labels': [1, 2, 3, 4, 5],
        }
    if arch == 'vgpn':
        return {
            'lr': [0.1, 0.05, 0.01],
            'wd': [1e-2, 5e-3, 1e-3, 5e-4],
            'lamb': [1.0, 0.5],
            'scale': [1.0, 5.0],
            'x_num_labels': [1, 2, 3, 4, 5],
        }
    raise ValueError('No hyper-parameter search space for arch %r' % (arch,))


if __name__ == "__main__":
    args = parser.parse_args()
    args.save = os.path.join(args.save, '_'.join([str(i) for i in [args.dataset, args.arch]]))
    os.makedirs(args.save, exist_ok=True)
    logger = get_logger(args)
    logger.info('\n' + pprint.pformat(args))

    # Fails fast (ValueError) on an unsupported arch before any worker starts.
    grid = ParameterGrid(get_search_space(args.arch))
    ctx = torch.multiprocessing.get_context("spawn")
    NUM_PROCESSING = 24
    pool = ctx.Pool(NUM_PROCESSING)

    pool_list = []
    print(len(grid))
    for i, param in enumerate(grid):
        # Each job gets its own copy of args; overrides must go on the copy.
        targs = deepcopy(args)
        targs.lr = param['lr']
        targs.wd = param['wd']
        targs.scale = param['scale']
        targs.num_labels = param['x_num_labels']
        if 'lamb' in param:
            targs.lamb = param['lamb']
        # Jobs are spread round-robin over GPU ids 1-4 (GPU 0 is never used).
        res = pool.apply_async(main_worker, args=(targs, i % 4 + 1))
        pool_list.append(res)
    pool.close()
    pool.join()

    rows = []
    for param, res in zip(grid, pool_list):
        accs = res.get().reshape(1, -1)
        rows.append(accs)
        logger.info(str(param) + ': %.1f (%.1f)' % (accs.mean(), accs.std()))

    # One row per grid point, one column per trial.
    results = np.concatenate(rows, axis=0)
    np.savetxt(os.path.join(args.save, 'results.txt'), results, fmt='%.2f', delimiter=',')